import qlib
from qlib.constant import REG_US, REG_CN
from .data_mgr import DataMgr


def init_qlib(data_dir, region=REG_CN, **kwargs):
    """Initialize qlib using ``data_dir`` as the data provider URI.

    Args:
        data_dir: Path/URI handed to ``qlib.init`` as ``provider_uri``.
        region: qlib region constant. Defaults to ``REG_CN``, which was the
            previously hard-coded value. Pass e.g. ``REG_US`` for US data.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``qlib.init``. They must not repeat ``provider_uri`` or
            ``region``, which are set explicitly here.
    """
    qlib.init(provider_uri=data_dir, region=region, **kwargs)


class QlibMgr:
    """Thin facade that initializes qlib and builds datasets via ``DataMgr``.

    Constructing an instance calls ``init_qlib`` as a side effect. No state
    is stored on the instance.
    """

    def __init__(self, data_dir="./data/cn_data"):
        """Initialize qlib against the data directory ``data_dir``."""
        init_qlib(data_dir=data_dir)

    def prepare_dataset(self, features, names):
        """Load a dataset for the given features.

        Args:
            features: Feature definitions forwarded to
                ``DataMgr.load_dataset``.
            names: Presumably the labels paired with ``features``. Confirm
                against ``DataMgr.load_dataset``.

        Returns:
            Whatever ``DataMgr().load_dataset`` returns.
        """
        # DataMgr expects features and names bundled as a single tuple.
        feature_spec = (features, names)
        loader = DataMgr()
        return loader.load_dataset(features=feature_spec)
